Pytorch 1 Sampler in DataLoader Ⅰ

Pytorch 1 Sampler in DataLoader Ⅰ

三月 02, 2020

cover

最近用到Lovasz-Softmax Loss(作者利用Lovasz Extension将转化为Submodular Set Functions的离散的mIoU/Jaccard Loss拓展到了连续空间中,从而使得其可优化)。在使用mini-batch训练时,为了使得局部的batch-mIoU能够接近全局的dataset-mIoU,作者使用了一种被称为equibatch的trick,使得训练结果有一定程度的提升:

We propose an additional trick for the optimization of the dataset–mIoU. Since the mIoU gives equal importance to each class, and to make the expectation of the batch–mIoU closer to the dataset–mIoU, it seems important to ensure that we feed the network with samples from all classes during training. In order to enforce this requirement, we sample the patches from the training by cycling over every classes, such that each class is visited at least once every |C| patches. This method is referred to as equibatch in our experiments.

由于作者是个重度拖延症,忘记了填这个坑(issue #19),所以只能根据提示去理解源码,自己实现了。

DataLoader for Starters

Initiation

CLASS

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)[SOURCE]

Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.

The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

Parameters

  • dataset (Dataset) – dataset from which to load the data.
  • batch_size (python:int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
  • num_workers (python:int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

DataLoaderPytorch中最基础的函数之一,主要用于处理数据集的处理,将输入的数据集按照shufflesampler的方式返回一个包含Tensor数据的迭代器,用于后续训练。

不考虑IterableDataset类型的DatasetDataLoader在初始化时需做以下判断,这决定了其基本的数据处理方式:

  • sampler/batch_sampler被指定时:

    1. shufflesampler不能共存:

    shuffle实质上被隐形包含于sampler之中,可以说shuffle本身可以代表一种sampler,因此当sampler被指定时,不可将shuffleTrue

    1
    2
    if sampler is not None and shuffle:
    raise ValueError('sampler option is mutually exclusive with shuffle')
    1. batch_sizeshufflesamplerdrop_lastbatch_sampler不能共存:

    batch_sizeshufflesamplerdrop_lastbatch_sampler等参数都被包含于batch_sampler中。

    1
    2
    3
    4
    5
    6
    7
    if batch_sampler is not None:
    # auto_collation with custom batch_sampler
    if batch_size != 1 or shuffle or sampler is not None or drop_last:
    raise ValueError('batch_sampler option is mutually exclusive '
    'with batch_size, shuffle, sampler, and drop_last')
    batch_size = None
    drop_last = False
    1. no auto_collation
    1
    2
    3
    4
    5
    elif batch_size is None:
    # no auto_collation
    if shuffle or drop_last:
    raise ValueError('batch_size=None option disables auto-batching '
    'and is mutually exclusive with shuffle, and drop_last')
  • sampler未被指定时:

    此时根据是否shuffle实例化相应的sampler,所以实质上DataLoader的数据分配永远是通过不同类型的sampler(包括custom sampler)内部的Iterator实现的:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
    # See NOTE [ Custom Samplers and IterableDataset ]
    sampler = _InfiniteConstantSampler()
    else: # map-style
    if shuffle:
    sampler = RandomSampler(dataset)
    else:
    sampler = SequentialSampler(dataset)
  • batch_sampler未被指定且batch_size存在:

    直接将batch_sampler实例化为BatchSampler

    1
    2
    3
    if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

最后:

1
2
self.sampler = sampler
self.batch_sampler = batch_sampler

所以DataLoader的初始化逻辑如下:

  1. sampler被指定时,判断冗余参数进行异常处理;
  2. batch_sampler被指定时,判断冗余参数进行异常处理;
  3. sampler未被指定时,根据是否shufflesampler实例化为不同Sampler
  4. batch_sampler未被指定时,在batch_sizeNone的情况下将batch_sampler实例化为BatchSampler

Hence:

  • sampler总是会被实例化为primitive或custom的Sampler
  • batch_sampler
    • 被指定时实例化为指定的primitive或custom的Sampler
    • 未被指定时根据batch_size是否存在决定是否实例化为primitive的BatchSampler

注意,这里还有两个很重要的通过@property装饰的属性:

  • _auto_collation:根据batch_sampler是否存在判断

    1
    2
    3
    @property
    def _auto_collation(self):
    return self.batch_sampler is not None
  • _index_sampler:根据_auto_collation判断使用sampler还是batch_sampler

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    @property
    def _index_sampler(self):
    # The actual sampler used for generating indices for `_DatasetFetcher`
    # (see _utils/fetch.py) to read data at each time. This would be
    # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
    # We can't change `.sampler` and `.batch_sampler` attributes for BC
    # reasons.
    if self._auto_collation:
    return self.batch_sampler
    else:
    return self.sampler

所以,_index_sampler这个对外界开放的实质为Iteratorsampler属性至此才揭开其真正面目。

Transition

DataLoader作为Iterator被遍历时,实质上在执行其__iter__函数:

1
2
3
4
5
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)

函数根据指定分配的进程数决定执行_BaseDataLoaderIter族的两类迭代执行函数。

类作为于的基类,其主要代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class _BaseDataLoaderIter(object):
def __init__(self, loader):
self._dataset = loader.dataset
......
self._index_sampler = loader._index_sampler
......
self._sampler_iter = iter(self._index_sampler)

def __iter__(self):
return self

def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration

def _next_data(self):
raise NotImplementedError

def __next__(self):
data = self._next_data()
......
return data

可见,DataLoader最终决定的sampler作为其_index_sampler属性被赋给了_BaseDataLoaderIter类的_index_sampler属性,并被转换为Iterator类型的_sampler_iter属性,供类内函数_next_index通过next()不断迭代。

此外,_BaseDataLoaderIter类作为Iterator保有的__next__函数通过调用当前未实现的_next_data函数获取具体数据,可见dataset内数据通过_next_index迭代产生的indices被读取并转换为_BaseDataLoaderIter的迭代结果,这个步骤是在_BaseDataLoaderIter子类必须实现的_next_data函数这个hook中完成的。

_SingleProcessDataLoaderIter

1
2
3
4
5
6
7
8
9
10
11
12
13
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
......
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data

对于单进程的DataLoader迭代器而言该hook的实现非常简单——直接调用_next_index即可,之后利用获取的index访问数据集获取具体数据,在执行完成锁页相关操作后返回数据。

_MultiProcessingDataLoaderIter

多进程的迭代器的实现就要复杂的多,使用各种signal完成进程间通信,而且由于:

Terminating multiprocessing logic requires very careful design.

所以很多代码和注释都是用于错误处理与正确有序的Exit。这里没有细看,只是大概梳理一下。

  • 在init过程中,首先完成基本设置,并设置主进程(当前进程)为守护进程,并初始化和进入各子进程:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    ......
    self._send_idx = 0 # idx of the next task to be sent to workers
    self._rcvd_idx = 0 # idx of the next task to be returned in __next__
    # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
    # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
    # \ (worker_id, data) if data is already fetched (out-of-order)
    self._task_info = {}
    self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
    ......
    for i in range(self._num_workers):
    index_queue = multiprocessing_context.Queue()
    # index_queue.cancel_join_thread()
    w = multiprocessing_context.Process(
    target=_utils.worker._worker_loop,
    args=(self._dataset_kind, self._dataset, index_queue,
    self._worker_result_queue, self._workers_done_event,
    self._auto_collation, self._collate_fn, self._drop_last,
    self._base_seed + i, self._worker_init_fn, i, self._num_workers))
    w.daemon = True
    # NB: Process.start() actually take some time as it needs to
    # start a process and pass the arguments over via a pipe.
    # Therefore, we only add a worker to self._workers list after
    # it started, so that we do not call .join() if program dies
    # before it starts, and __del__ tries to join but will get:
    # AssertionError: can only join a started process.
    w.start()
    self._index_queues.append(index_queue)
    self._workers.append(w)
    self._workers_status.append(True)

    这里子进程的目标函数是_utils.worker._worker_loop(in worker.py),在该函数中完成从index_queue中获取index并利用fetcherdataset中读取相应的数据整套操作:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
    auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
    num_workers):
    ......
    while watchdog.is_alive():
    try:
    r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
    except queue.Empty:
    continue
    ......
    idx, index = r
    ......
    try:
    data = fetcher.fetch(index)
    data_queue.put((idx, data))
    ......

    可见,子进程从主进程传递的index_queue中获取index,完成数据读取操作后,将数据放入data_queue后返回给主进程。

    init的最后,启动预抓取循环,开始通过_try_put_index函数获取index

    1
    2
    3
    4
    # prime the prefetch loop
    # why 2?
    for _ in range(2 * self._num_workers):
    self._try_put_index()
  • _try_put_index函数调用基类的_next_index函数获取相应的index,找到当前alive的第一个workerindex传递到其index_queue中:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    def _try_put_index(self):
    assert self._tasks_outstanding < 2 * self._num_workers
    try:
    index = self._next_index()
    except StopIteration:
    return
    for _ in range(self._num_workers): # find the next active worker, if any
    worker_queue_idx = next(self._worker_queue_idx_cycle)
    if self._workers_status[worker_queue_idx]:
    break
    else:
    # not found (i.e., didn't break)
    return

    self._index_queues[worker_queue_idx].put((self._send_idx, index))
    self._task_info[self._send_idx] = (worker_queue_idx,)
    self._tasks_outstanding += 1
    self._send_idx += 1
  • 当外界通过基函数调用该子类的_next_data函数时,开始循环执行以下操作:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    def _next_data(self):
    while True:
    # If the worker responsible for `self._rcvd_idx` has already ended
    # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
    # we try to advance `self._rcvd_idx` to find the next valid index.
    #
    # This part needs to run in the loop because both the `self._get_data()`
    # call and `_IterableDatasetStopIteration` check below can mark
    # extra worker(s) as dead.
    while self._rcvd_idx < self._send_idx:
    info = self._task_info[self._rcvd_idx]
    worker_id = info[0]
    if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
    break
    del self._task_info[self._rcvd_idx]
    self._rcvd_idx += 1
    else:
    # no valid `self._rcvd_idx` is found (i.e., didn't break)
    self._shutdown_workers()
    raise StopIteration

    # Now `self._rcvd_idx` is the batch index we want to fetch

    # Check if the next sample has already been generated
    if len(self._task_info[self._rcvd_idx]) == 2:
    data = self._task_info.pop(self._rcvd_idx)[1]
    return self._process_data(data)

    assert not self._shutdown and self._tasks_outstanding > 0
    idx, data = self._get_data()
    self._tasks_outstanding -= 1
    ......
    if idx != self._rcvd_idx:
    # store out-of-order samples
    self._task_info[idx] += (data,)
    else:
    del self._task_info[idx]
    return self._process_data(data)

    这个最重要的函数的主要流程如下:

    1. self._rcvd_idx < self._send_idx时,表明仍有已发送序号的数据未被接收,因此循环去寻找,此时只有三种情况:

      a. 已发送的数据正在被worker处理(self._workers_status[worker_id]),此时break去接收数据;

      b. 已发送的数据已经到达(len(info) == 2),此时break去处理数据;

      c. 已发送的数据丢失,因此该idx对应的数据无法被接收,删除该任务信息,去接收下一个序号的数据。

      当所有数据都被接收完,表示数据传输完成,关闭所有workers

    2. 在循环外首先处理1b情况,读取相应数据并将其从list中弹出,调用_process_data函数:

      1
      2
      3
      4
      5
      6
      def _process_data(self, data):
      self._rcvd_idx += 1
      self._try_put_index()
      if isinstance(data, ExceptionWrapper):
      data.reraise()
      return data

      首先rcvd_idx自增表示数据完成接收,去接收下一数据,然后通过_try_put_index函数发送一个序号,并返回此次获取的数据,同时此数据也作为_next_data的返回值逐步被外界接收。

    3. 再处理1a情况,调用_get_data函数获取当前可以获取的数据,_get_data函数通过调用_try_get_data函数从而获取_data_queue中到达的数据,返回。

      由于在发送index时同时发送了send_idx,因此进行一次判定来判断获取数据是否为期望数据:

      • send_idxrcvd_idx相同,是期望数据,删除信息,调用_process_data函数处理;

      • send_idxrcvd_idx不同,不是期望数据,将该idx对应的数据放回,等待匹配。

        假定前n次完美匹配,那么该idxsend_idx)一定大于rcvd_idx,在未来一定可以成功匹配。

References